{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([[0., 0., 0., ..., 0., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 0.],\n",
       "        [0., 1., 0., ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0., ..., 0., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 0.],\n",
       "        [0., 0., 0., ..., 0., 0., 0.]]),\n",
       " array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
       "        1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
       "        0, 0, 0, 0, 0, 0]),\n",
       " (50, 820),\n",
       " (50,))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy as np\n",
    "import os\n",
    "import re\n",
    "\n",
    "\n",
    "def load_data():\n",
    "    files = []\n",
    "    y = []\n",
    "    for cate in os.listdir('邮件数据'):\n",
    "        for file in os.listdir('邮件数据/' + cate):\n",
    "            with open('邮件数据/' + cate + '/' + file) as fr:\n",
    "                #print('邮件数据/' + cate + '/' + file)\n",
    "                lines = []\n",
    "                for line in fr.readlines():\n",
    "                    line = line.strip()\n",
    "                    line = [\n",
    "                        word for word in re.split('\\W', line) if len(word) >= 2\n",
    "                    ]\n",
    "                    lines.extend(line)\n",
    "                files.append(lines)\n",
    "                y.append(1 if cate == '垃圾' else 0)\n",
    "\n",
    "    #print(files[0])\n",
    "\n",
    "    #求所有单词的集合\n",
    "    words = []\n",
    "    for lines in files:\n",
    "        words.extend(lines)\n",
    "    words = list(set(words))\n",
    "\n",
    "    #句子数字化,统计每个词在每个句子中出现的次数\n",
    "    x = np.zeros((len(files), len(words)))\n",
    "    for i in range(len(files)):\n",
    "        for j in range(len(files[i])):\n",
    "            if files[i][j] in words:\n",
    "                x[i, words.index(files[i][j])] += 1\n",
    "\n",
    "    #1是违规言论,0合法言论\n",
    "    y = np.array(y)\n",
    "    return x, y\n",
    "\n",
    "\n",
    "x, y = load_data()\n",
    "x, y, x.shape, y.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((820,), (820,), -0.6931471805599453, -0.6931471805599453)"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def train(x, y):\n",
    "\n",
    "    #首先求总体的违规率,等于 违规次数 / 总次数\n",
    "    p1 = y.sum() / len(y)\n",
    "    p0 = 1 - p1\n",
    "\n",
    "    #取对数概率\n",
    "    p1 = np.log(p1)\n",
    "    p0 = np.log(p0)\n",
    "\n",
    "    #根据y的值,切分x为正例和反例\n",
    "    x_1 = x[y == 1]\n",
    "    x_0 = x[y == 0]\n",
    "\n",
    "    #统计在正例中,所有词出现的次数,当然在一个句子中出现多次也只算1次\n",
    "    #也就是说,是每个词出现的句子数\n",
    "    #最后要加1是为了避免0的情况,也就是说,所有词最少出现1次\n",
    "    p1_given_word = x_1.sum(axis=0) + 1\n",
    "\n",
    "    #上面统计的是次数,除以总次数就等于概率了.\n",
    "    #也就是每个词,出现在正例句子中的概率\n",
    "    p1_given_word = p1_given_word / p1_given_word.sum()\n",
    "\n",
    "    #取对数概率\n",
    "    p1_given_word = np.log(p1_given_word)\n",
    "\n",
    "    #p0_given_word的计算同理\n",
    "    p0_given_word = x_0.sum(axis=0) + 1\n",
    "    p0_given_word = p0_given_word / p0_given_word.sum()\n",
    "    p0_given_word = np.log(p0_given_word)\n",
    "\n",
    "    #取对数概率\n",
    "    return p1_given_word, p0_given_word, p1, p0\n",
    "\n",
    "\n",
    "p1_given_word, p0_given_word, p1, p0 = train(x, y)\n",
    "p1_given_word.shape, p0_given_word.shape, p1, p0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#测试\n",
    "def pred(x):\n",
    "    #本来p1_given_x应该是每个单词属于p1的概率连乘, 但是因为前面取了对数, 所以这里求和就可以了\n",
    "    p1_given_x = x.dot(p1_given_word) + p1\n",
    "    p0_given_x = x.dot(p0_given_word) + p0\n",
    "\n",
    "    return 1 if p1_given_x > p0_given_x else 0\n",
    "\n",
    "\n",
    "correct = 0\n",
    "for xi, yi in zip(x, y):\n",
    "    if pred(xi) == yi:\n",
    "        correct += 1\n",
    "\n",
    "correct / len(x)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
